#include "cell.h"
#include <qthread.h>
#include "swarch.h"
#ifdef __sw_slave__
#include <simd.h>
#include <qthread_slave.h>
#include "dma_macros_new.h"
#include "cal.h"
#endif
/* Number of work chunks a pack region is split into (presumably one per slave PE);
   see build_param_real. */
#define NPE_PACK 24

/* Round ptr up to the next multiple of alignment (alignment must be a power of two).
   The argument is parenthesized and the mask is built in long so that compound
   arguments and unsigned-int alignments do not corrupt the pointer. */
#define UP_ALIGN_PTR(ptr, alignment) ((char*)((((long)(ptr)) + ((long)(alignment) - 1)) & ~((long)(alignment) - 1)))
/* Round size up to the next multiple of alignment (alignment must be a power of two). */
#define UP_ALIGN_SIZE(size, alignment) ((((size) + (alignment) - 1)) & ~ ((alignment) - 1))
/* Size in bytes of the local staging buffer used by the slave pack/unpack routines. */
#define MAX_BUF_SIZE 49152

#ifdef __sw_slave__
/*
 * Serialize the positions (cell->x) of one cell into `buf`.
 *
 * The cell's cellmeta_t (read starting at &cell->basis) supplies natom.
 * Record layout, as byte offsets into buf:
 *   [0, sizeof(long))     record size (same value as the return value)
 *   [16, 16 + bytes)      meta.natom elements copied from cell->x
 *   then padding up to the next multiple of 32 bytes
 * The gap before byte 16 and the tail padding are not initialized.
 *
 * The record is assembled in a MAX_BUF_SIZE-byte local buffer and copied out
 * with a single pe_put, so it must fit in MAX_BUF_SIZE bytes.
 * NOTE(review): that limit is not checked here.
 *
 * Returns the record size in bytes.
 */
size_t pack_cell_forward_x(char *buf, celldata_t *cell) {
  dma_init();
  char lbuf[MAX_BUF_SIZE];
  cellmeta_t meta;
  pe_get(&cell->basis, &meta, sizeof(cellmeta_t));
  dma_syn();
  size_t offset = UP_ALIGN_SIZE(16, 8);
  size_t bytes = sizeof(*cell->x) * meta.natom;
  /* Braced: pe_get is presumably a macro from dma_macros_new.h whose
     expansion is not visible here. */
  if (meta.natom) {
    pe_get(cell->x, lbuf + offset, bytes);
  }
  dma_syn();
  offset = UP_ALIGN_SIZE(offset + bytes, 32);
  *(long*)lbuf = offset;

  pe_put(buf, lbuf, offset);
  dma_syn();
  return offset;
}
/*
 * Deserialize a record produced by pack_cell_forward_x from `buf` into cell->x.
 *
 * A fixed 256-byte prefix is fetched first so the record-size header
 * (a long at offset 0) can be read. The rest of the record, if longer, is
 * then fetched together with this cell's cellmeta_t (read from &cell->basis).
 * The payload at offset 16 is written to cell->x. The element count is this
 * cell's own meta.natom, which is presumably equal to the sender's count
 * (TODO confirm).
 * NOTE(review): 256 bytes are read even when the record is shorter. The size
 * read from the header is not checked against MAX_BUF_SIZE before the second
 * fetch into the local buffer.
 *
 * Returns the record size stored in the header.
 */
size_t unpack_cell_forward_x(char *buf, celldata_t *cell) {
  dma_init();
  char lbuf[MAX_BUF_SIZE];
  /* Fetch the 256-byte prefix to learn the record size. */
  pe_get(buf, lbuf, 256);
  dma_syn();
  size_t cell_size = *(long*)lbuf;
  cellmeta_t meta;
  pe_get(&cell->basis, &meta, sizeof(cellmeta_t));
  if (cell_size > 256) {
    pe_get(buf + 256, lbuf + 256, cell_size - 256);
  }
  dma_syn();
  size_t offset = UP_ALIGN_SIZE(16, 8);
  /* Braced: pe_put is presumably a macro from dma_macros_new.h. */
  if (meta.natom) {
    pe_put(cell->x, lbuf + offset, sizeof(*cell->x) * meta.natom);
  }
  dma_syn();
  return cell_size;
}
/*
 * Serialize cell->shake_tmp of one cell into `buf`.
 *
 * The cell's cellmeta_t (read starting at &cell->basis) supplies natom.
 * Record layout, as byte offsets into buf:
 *   [0, sizeof(long))     record size (same value as the return value)
 *   [16, 16 + bytes)      meta.natom elements copied from cell->shake_tmp
 *   then padding up to the next multiple of 32 bytes
 * The gap before byte 16 and the tail padding are not initialized.
 *
 * The record is assembled in a MAX_BUF_SIZE-byte local buffer and copied out
 * with a single pe_put, so it must fit in MAX_BUF_SIZE bytes.
 * NOTE(review): that limit is not checked here.
 *
 * Returns the record size in bytes.
 */
size_t pack_cell_forward_shake(char *buf, celldata_t *cell) {
  dma_init();
  char lbuf[MAX_BUF_SIZE];
  cellmeta_t meta;
  pe_get(&cell->basis, &meta, sizeof(cellmeta_t));
  dma_syn();
  size_t offset = UP_ALIGN_SIZE(16, 8);
  size_t bytes = sizeof(*cell->shake_tmp) * meta.natom;
  /* Braced: pe_get is presumably a macro from dma_macros_new.h whose
     expansion is not visible here. */
  if (meta.natom) {
    pe_get(cell->shake_tmp, lbuf + offset, bytes);
  }
  dma_syn();
  offset = UP_ALIGN_SIZE(offset + bytes, 32);
  *(long*)lbuf = offset;

  pe_put(buf, lbuf, offset);
  dma_syn();
  return offset;
}
/*
 * Deserialize a record produced by pack_cell_forward_shake from `buf` into
 * cell->shake_tmp.
 *
 * A fixed 256-byte prefix is fetched first so the record-size header
 * (a long at offset 0) can be read. The rest of the record, if longer, is
 * then fetched together with this cell's cellmeta_t (read from &cell->basis).
 * The payload at offset 16 is written to cell->shake_tmp. The element count is
 * this cell's own meta.natom, which is presumably equal to the sender's count
 * (TODO confirm).
 * NOTE(review): 256 bytes are read even when the record is shorter. The size
 * read from the header is not checked against MAX_BUF_SIZE before the second
 * fetch into the local buffer.
 *
 * Returns the record size stored in the header.
 */
size_t unpack_cell_forward_shake(char *buf, celldata_t *cell) {
  dma_init();
  char lbuf[MAX_BUF_SIZE];
  /* Fetch the 256-byte prefix to learn the record size. */
  pe_get(buf, lbuf, 256);
  dma_syn();
  size_t cell_size = *(long*)lbuf;
  cellmeta_t meta;
  pe_get(&cell->basis, &meta, sizeof(cellmeta_t));
  if (cell_size > 256) {
    pe_get(buf + 256, lbuf + 256, cell_size - 256);
  }
  dma_syn();
  size_t offset = UP_ALIGN_SIZE(16, 8);
  /* Braced: pe_put is presumably a macro from dma_macros_new.h. */
  if (meta.natom) {
    pe_put(cell->shake_tmp, lbuf + offset, sizeof(*cell->shake_tmp) * meta.natom);
  }
  dma_syn();
  return cell_size;
}
/*
 * Serialize the forces (cell->f) of one cell into `buf`.
 *
 * The cell's cellmeta_t (read starting at &cell->basis) supplies natom.
 * Record layout, as byte offsets into buf:
 *   [0, sizeof(long))     record size (same value as the return value)
 *   [16, 16 + bytes)      meta.natom elements copied from cell->f
 *   then padding up to the next multiple of 32 bytes
 * The gap before byte 16 and the tail padding are not initialized.
 *
 * The record is assembled in a MAX_BUF_SIZE-byte local buffer and copied out
 * with a single pe_put, so it must fit in MAX_BUF_SIZE bytes.
 * NOTE(review): that limit is not checked here.
 *
 * Returns the record size in bytes.
 */
size_t pack_cell_reverse_f(char *buf, celldata_t *cell) {
  dma_init();
  char lbuf[MAX_BUF_SIZE];
  cellmeta_t meta;
  pe_get(&cell->basis, &meta, sizeof(cellmeta_t));
  dma_syn();
  size_t offset = UP_ALIGN_SIZE(16, 8);
  size_t bytes = sizeof(*cell->f) * meta.natom;
  /* Braced: pe_get is presumably a macro from dma_macros_new.h whose
     expansion is not visible here. */
  if (meta.natom) {
    pe_get(cell->f, lbuf + offset, bytes);
  }
  dma_syn();
  offset = UP_ALIGN_SIZE(offset + bytes, 32);
  *(long*)lbuf = offset;

  pe_put(buf, lbuf, offset);
  dma_syn();
  return offset;
}
/*
 * Accumulate a record produced by pack_cell_reverse_f from `buf` into cell->f.
 *
 * Steps:
 *   1. Read the record: a 256-byte prefix first, then the rest if the header
 *      says the record is longer.
 *   2. Fetch the current cell->f into a local array.
 *   3. Add the received values (payload at offset 16) element by element.
 *      vecaddv(dst, a, b) is assumed to compute dst = a + b.
 *   4. Write the sums back to cell->f.
 * The element count is this cell's own meta.natom (from &cell->basis).
 * NOTE(review): 256 bytes are read even when the record is shorter, and the
 * header size is not checked against MAX_BUF_SIZE. This routine also keeps
 * CELL_CAP vec<real> elements on the stack next to the MAX_BUF_SIZE buffer.
 *
 * Returns the record size stored in the header.
 */
size_t unpack_cell_reverse_f(char *buf, celldata_t *cell) {
  dma_init();
  char lbuf[MAX_BUF_SIZE];
  /* Fetch the 256-byte prefix to learn the record size. */
  pe_get(buf, lbuf, 256);
  dma_syn();
  size_t cell_size = *(long*)lbuf;
  cellmeta_t meta;
  pe_get(&cell->basis, &meta, sizeof(cellmeta_t));
  if (cell_size > 256) {
    pe_get(buf + 256, lbuf + 256, cell_size - 256);
  }
  dma_syn();

  size_t offset = UP_ALIGN_SIZE(16, 8);
  vec<real> fbuf[CELL_CAP];
  /* Braced: pe_get/pe_put are presumably macros from dma_macros_new.h. */
  if (meta.natom) {
    pe_get(cell->f, fbuf, sizeof(*cell->f) * meta.natom);
  }
  dma_syn();
  vec<real> *fptr = (vec<real>*)(lbuf + offset);
  for (int i = 0; i < meta.natom; i ++) {
    vecaddv(fbuf[i], fbuf[i], fptr[i]);
  }
  if (meta.natom) {
    pe_put(cell->f, fbuf, sizeof(*cell->f) * meta.natom);
  }
  dma_syn();
  return cell_size;
}
#endif
/*
 * Byte size of one packed record for a cell whose payload is cell->natom
 * elements of vec<real>.
 *
 * The record is a header padded to 16 bytes followed by the payload, with
 * the total rounded up to a multiple of 32. This is the same arithmetic the
 * pack_cell_* routines use, assuming the packed fields are vec<real>.
 */
size_t estimate_vec_real(celldata_t *cell) {
  const size_t header_bytes = UP_ALIGN_SIZE(16, 8);
  const size_t payload_bytes = sizeof(vec<real>) * cell->natom;
  return UP_ALIGN_SIZE(header_bytes + payload_bytes, 32);
}
/* The three exchanged quantities share one size estimate (presumably all
   per-atom vec<real> arrays). */
#define estimate_forward_x estimate_vec_real
#define estimate_reverse_f estimate_vec_real
#define estimate_forward_shake estimate_vec_real
/*
 * Describe the cell region [xlo,xhi) x [ylo,yhi) x [zlo,zhi) in *pm for packing.
 *
 * Cells are visited with x outermost and z innermost. Their records are laid
 * out back to back, each taking estimate_vec_real(cell) bytes.
 *
 * The visit order is cut into consecutive chunks of
 * ntask_pe = ceil(ncells / NPE_PACK) cells. pm->offset[p] receives the byte
 * offset at which chunk p starts, and pm->total receives the byte size of the
 * whole region. Only chunks that contain at least one cell get an offset
 * written; other pm->offset entries are left untouched.
 */
void build_param_real(vec_pack_param_t *pm, cellgrid_t *grid, int xlo, int xhi, int ylo, int yhi, int zlo, int zhi) {
  pm->grid = grid;
  pm->xlo = xlo;
  pm->xhi = xhi;
  pm->ylo = ylo;
  pm->yhi = yhi;
  pm->zlo = zlo;
  pm->zhi = zhi;
  int xlen = xhi - xlo;
  int ylen = yhi - ylo;
  int zlen = zhi - zlo;

  int ntask_pe = (xlen * ylen * zlen + NPE_PACK - 1) / NPE_PACK;
  if (ntask_pe <= 0) {
    /* Empty region: the loops below would not run anyway. Return before
       make_magic() is asked to build a divisor for 0 or a negative count. */
    pm->total = 0;
    return;
  }
  div_magic_t div_npack_magic;
  make_magic(&div_npack_magic, ntask_pe);

  size_t offset = 0;
  int idx = 0; /* linear position of (i, j, k) in the visit order */
  for (int i = xlo; i < xhi; i ++) {
    for (int j = ylo; j < yhi; j ++) {
      for (int k = zlo; k < zhi; k ++, idx ++) {
        /* Chunk owning this cell; MAGIC_DIV presumably computes idx / ntask_pe. */
        int chunk = MAGIC_DIV(idx, div_npack_magic);
        if (idx == chunk * ntask_pe) {
          pm->offset[chunk] = offset; /* first cell of this chunk */
        }
        celldata_t *cell = get_cell_xyz(grid, i, j, k);
        offset += estimate_vec_real(cell);
      }
    }
  }
  pm->total = offset;
}
/*
 * Fill archdata->pack_params with the cell regions used by the halo exchange,
 * via build_param_real.
 *
 * Each region is used by two descriptors:
 *   - PACK_FWD_<face> and UNPACK_REV_<face>: the nn-thick layer on the inner
 *     side of the local range [0, nlocal) at that face.
 *   - UNPACK_FWD_<face> and PACK_REV_<face>: the nn-thick layer at the outer
 *     edge of [dim.lo, dim.hi) at that face (presumably ghost cells).
 *
 * The Y regions span all of [lo.z, hi.z), and the X regions span
 * [lo.y, hi.y) x [lo.z, hi.z). This is consistent with a Z, then Y, then X
 * staged exchange (inferred from the ranges; confirm against the caller).
 */
void build_pack_params_sw(sw_archdata_t *archdata, cellgrid_t *grid) {
  vec_pack_param_t *pp = archdata->pack_params;
  const int nn = grid->nn;
  const int nx = grid->nlocal.x, ny = grid->nlocal.y, nz = grid->nlocal.z;
  const int lox = grid->dim.lo.x, loy = grid->dim.lo.y, loz = grid->dim.lo.z;
  const int hix = grid->dim.hi.x, hiy = grid->dim.hi.y, hiz = grid->dim.hi.z;

  /* -z face */
  build_param_real(pp + PACK_FWD_NEG_Z,   grid, 0, nx, 0, ny, 0, nn);
  build_param_real(pp + UNPACK_REV_NEG_Z, grid, 0, nx, 0, ny, 0, nn);
  build_param_real(pp + UNPACK_FWD_NEG_Z, grid, 0, nx, 0, ny, loz, loz + nn);
  build_param_real(pp + PACK_REV_NEG_Z,   grid, 0, nx, 0, ny, loz, loz + nn);
  /* +z face */
  build_param_real(pp + PACK_FWD_POS_Z,   grid, 0, nx, 0, ny, nz - nn, nz);
  build_param_real(pp + UNPACK_REV_POS_Z, grid, 0, nx, 0, ny, nz - nn, nz);
  build_param_real(pp + UNPACK_FWD_POS_Z, grid, 0, nx, 0, ny, hiz - nn, hiz);
  build_param_real(pp + PACK_REV_POS_Z,   grid, 0, nx, 0, ny, hiz - nn, hiz);
  /* -y face */
  build_param_real(pp + PACK_FWD_NEG_Y,   grid, 0, nx, 0, nn, loz, hiz);
  build_param_real(pp + UNPACK_REV_NEG_Y, grid, 0, nx, 0, nn, loz, hiz);
  build_param_real(pp + UNPACK_FWD_NEG_Y, grid, 0, nx, loy, loy + nn, loz, hiz);
  build_param_real(pp + PACK_REV_NEG_Y,   grid, 0, nx, loy, loy + nn, loz, hiz);
  /* +y face */
  build_param_real(pp + PACK_FWD_POS_Y,   grid, 0, nx, ny - nn, ny, loz, hiz);
  build_param_real(pp + UNPACK_REV_POS_Y, grid, 0, nx, ny - nn, ny, loz, hiz);
  build_param_real(pp + UNPACK_FWD_POS_Y, grid, 0, nx, hiy - nn, hiy, loz, hiz);
  build_param_real(pp + PACK_REV_POS_Y,   grid, 0, nx, hiy - nn, hiy, loz, hiz);
  /* -x face */
  build_param_real(pp + PACK_FWD_NEG_X,   grid, 0, nn, loy, hiy, loz, hiz);
  build_param_real(pp + UNPACK_REV_NEG_X, grid, 0, nn, loy, hiy, loz, hiz);
  build_param_real(pp + UNPACK_FWD_NEG_X, grid, lox, lox + nn, loy, hiy, loz, hiz);
  build_param_real(pp + PACK_REV_NEG_X,   grid, lox, lox + nn, loy, hiy, loz, hiz);
  /* +x face */
  build_param_real(pp + PACK_FWD_POS_X,   grid, nx - nn, nx, loy, hiy, loz, hiz);
  build_param_real(pp + UNPACK_REV_POS_X, grid, nx - nn, nx, loy, hiy, loz, hiz);
  build_param_real(pp + UNPACK_FWD_POS_X, grid, hix - nn, hix, loy, hiy, loz, hiz);
  build_param_real(pp + PACK_REV_POS_X,   grid, hix - nn, hix, loy, hiy, loz, hiz);
}

/*
 * Instantiate pack_brick_sw.h once per exchanged quantity.
 *
 * CAT(x, y) expands its arguments first and then pastes them as "x_y".
 * The two-level definition is needed because ## suppresses expansion of its
 * operands. Each #include below runs with SUFFIX set to one of forward_x,
 * forward_shake or reverse_f. The header presumably builds names such as
 * CAT(pack_cell, SUFFIX) -> pack_cell_forward_x, matching the functions and
 * estimate_* aliases defined above.
 * NOTE(review): the header contents are not visible here; confirm which names
 * it generates.
 * NOTE(review): identifiers containing "__" (e.g. __CAT__) are reserved for
 * the implementation. Renaming is only safe if the header does not reference
 * __CAT__ directly.
 */
#define __CAT__(x, y) x ## _ ## y
#define CAT(x, y) __CAT__(x, y)
#define SUFFIX forward_x
#include "pack_brick_sw.h"
#undef SUFFIX
#define SUFFIX forward_shake
#include "pack_brick_sw.h"
#undef SUFFIX
#define SUFFIX reverse_f
#include "pack_brick_sw.h"
#undef SUFFIX
